{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "120e57f5",
   "metadata": {},
   "source": [
    "Model surgery\n",
    "==============================\n",
    "\n",
    "Usually, Flax modules and optimizers track and update the params for you. But there may be some time when you want to do some model surgery and tweak the param tensors yourself. This guide shows you how to do the trick."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c3bfb0e",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "413f8b2d",
   "metadata": {
    "tags": [
     "skip-execution"
    ]
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade -q pip jax jaxlib flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b002c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flax import traverse_util\n",
    "from flax import linen as nn\n",
    "from flax.core import freeze\n",
    "import jax\n",
    "import optax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1060b519",
   "metadata": {},
   "source": [
    "Surgery with Flax Modules\n",
    "--------------------------------\n",
    "\n",
    "Let's create a small convolutional neural network model for our demo.\n",
    "\n",
    "As usual, you can run `CNN.init(...)['params']` to get the `params` to pass and modify it in every step of your training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755ae323",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "      x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
    "      x = nn.relu(x)\n",
    "      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "      x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
    "      x = nn.relu(x)\n",
    "      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "      x = x.reshape((x.shape[0], -1))\n",
    "      x = nn.Dense(features=256)(x)\n",
    "      x = nn.relu(x)\n",
    "      x = nn.Dense(features=10)(x)\n",
    "      x = nn.log_softmax(x)\n",
    "      return x\n",
    "\n",
    "def get_initial_params(key):\n",
    "    init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)\n",
    "    initial_params = CNN().init(key, init_shape)['params']\n",
    "    return initial_params\n",
    "\n",
    "key = jax.random.key(0)\n",
    "params = get_initial_params(key)\n",
    "\n",
    "jax.tree_util.tree_map(jnp.shape, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "170273f8",
   "metadata": {},
   "source": [
    "Note that what returned as `params` is a `FrozenDict`, which contains a few JAX arrays as kernel and bias. \n",
    "\n",
    "A `FrozenDict` is nothing more than a read-only dict, and Flax made it read-only because of the functional nature of JAX: JAX arrays are immutable, and the new `params` need to replace the old `params`. Making the dict read-only ensures that no in-place mutation of the dict can happen accidentally during the training and updating.\n",
    "\n",
    "One way to actually modify the params outside of a Flax module is to explicitly flatten it and creates a mutable dict. Note that you can use a separator `sep` to join all nested keys. If no `sep` is given, the key will be a tuple of all nested keys."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7ec7741",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a flattened key-value list.\n",
    "flat_params = traverse_util.flatten_dict(params, sep='/')\n",
    "\n",
    "jax.tree_util.tree_map(jnp.shape, flat_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2adda656",
   "metadata": {},
   "source": [
    "Now you can do whatever you want with the params. When you are done, unflatten it back and use it in future training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb975feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Somehow modify a layer\n",
    "dense_kernel = flat_params['Dense_1/kernel']\n",
    "flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)\n",
    "\n",
    "# Unflatten.\n",
    "unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')\n",
    "# Refreeze.\n",
    "unflat_params = freeze(unflat_params)\n",
    "jax.tree_util.tree_map(jnp.shape, unflat_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3462cd8",
   "metadata": {},
   "source": [
    "Surgery with Optimizers\n",
    "--------------------------------\n",
    "\n",
    "When using `Optax` as an optimizer, the ``opt_state`` is actually a nested tuple\n",
    "of the states of individual gradient transformations that compose the optimizer.\n",
    "These states contain pytrees that mirror the parameter tree, and can be modified\n",
    "the same way: flattening, modifying, unflattening, and then recreating a new\n",
    "optimizer state that mirrors the original state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cbecb63",
   "metadata": {},
   "outputs": [],
   "source": [
    "tx = optax.adam(1.0)\n",
    "opt_state = tx.init(params)\n",
    "\n",
    "# The optimizer state is a tuple of gradient transformation states.\n",
    "jax.tree_util.tree_map(jnp.shape, opt_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18f1cebb",
   "metadata": {},
   "source": [
    "The pytrees inside the optimizer state follow the same structure as the\n",
    "parameters and can be flattened / modified exactly the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b5e25f",
   "metadata": {},
   "outputs": [],
   "source": [
    "flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')\n",
    "flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')\n",
    "\n",
    "jax.tree_util.tree_map(jnp.shape, flat_mu)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5c4479e",
   "metadata": {},
   "source": [
    "After modification, re-create optimizer state. Use this for future training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dcac8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_state = (\n",
    "    opt_state[0]._replace(\n",
    "        mu=traverse_util.unflatten_dict(flat_mu, sep='/'),\n",
    "        nu=traverse_util.unflatten_dict(flat_nu, sep='/'),\n",
    "    ),\n",
    ") + opt_state[1:]\n",
    "jax.tree_util.tree_map(jnp.shape, opt_state)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "md,ipynb",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
